{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Parsing.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/harvardnlp/pytorch-struct/blob/master/notebooks/Unsupervised_CFG.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4ifr3ACNQ0SO",
        "colab_type": "code",
        "outputId": "f4551d3b-c9e1-429a-e27c-935f090861a3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "!pip install -qqq torchtext -qqq pytorch-transformers dgl\n",
        "!pip install -qqqU git+https://github.com/harvardnlp/pytorch-struct"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  Building wheel for torch-struct (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fDSxxDUXTNuV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torchtext\n",
        "import torch\n",
        "from torch_struct import SentCFG\n",
        "from torch_struct.networks import NeuralCFG\n",
        "import torch_struct.data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FE0bxak5TWRQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Download and the load default data.\n",
        "WORD = torchtext.data.Field(include_lengths=True)\n",
        "UD_TAG = torchtext.data.Field(init_token=\"<bos>\", eos_token=\"<eos>\", include_lengths=True)\n",
        "\n",
        "# Download and the load default data.\n",
        "train, val, test = torchtext.datasets.UDPOS.splits(\n",
        "    fields=(('word', WORD), ('udtag', UD_TAG), (None, None)), \n",
        "    filter_pred=lambda ex: 5 < len(ex.word) < 30\n",
        ")\n",
        "\n",
        "WORD.build_vocab(train.word, min_freq=3)\n",
        "UD_TAG.build_vocab(train.udtag)\n",
        "train_iter = torch_struct.data.TokenBucket(train, \n",
        "    batch_size=200,\n",
        "    device=\"cuda:0\")\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "id3sRBHuUYIL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "H = 256\n",
        "T = 30\n",
        "NT = 30\n",
        "model = NeuralCFG(len(WORD.vocab), T, NT, H)\n",
        "model.cuda()\n",
        "opt = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.75, 0.999])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NLSnd7cyUH_3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train():\n",
        "    #model.train()\n",
        "    losses = []\n",
        "    for epoch in range(10):\n",
        "        for i, ex in enumerate(train_iter):\n",
        "            opt.zero_grad()\n",
        "            words, lengths = ex.word \n",
        "            N, batch = words.shape\n",
        "            words = words.long()\n",
        "            params = model(words.cuda().transpose(0, 1))\n",
        "            dist = SentCFG(params, lengths=lengths)\n",
        "            loss = dist.partition.mean()\n",
        "            (-loss).backward()\n",
        "            losses.append(loss.detach())\n",
        "            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)\n",
        "            opt.step()\n",
        "\n",
        "            if i % 100 == 1:            \n",
        "                print(-torch.tensor(losses).mean(), words.shape)\n",
        "                losses = []"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "__FoQZyRZ41m",
        "colab_type": "code",
        "outputId": "5387a40b-77c8-4e09-b207-8819e164b59c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "train()"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor(52.3219) torch.Size([29, 18])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9XJL_a9KbUdc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "for i, ex in enumerate(train_iter):\n",
        "    opt.zero_grad()\n",
        "    words, lengths = ex.word \n",
        "\n",
        "    N, batch = words.shape\n",
        "    words = words.long()\n",
        "    params = terms(words.transpose(0, 1)), rules(batch), roots(batch)\n",
        "\n",
        "    tree = CKY(MaxSemiring).marginals(params, lengths=lengths, _autograd=True)\n",
        "    print(tree)\n",
        "    break"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "swe-bvF5KtsI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def split(spans):\n",
        "    batch, N = spans.shape[:2]\n",
        "    splits = []\n",
        "    for b in range(batch):\n",
        "        cover = spans[b].nonzero()\n",
        "        left = {i: [] for i in range(N)}\n",
        "        right = {i: [] for i in range(N)}\n",
        "        batch_split = {}\n",
        "        for i in range(cover.shape[0]):\n",
        "            i, j, A = cover[i].tolist()\n",
        "            left[i].append((A, j, j - i + 1))\n",
        "            right[j].append((A, i, j - i + 1))\n",
        "        for i in range(cover.shape[0]):\n",
        "            i, j, A = cover[i].tolist()\n",
        "            B = None\n",
        "            for B_p, k, a_span in left[i]:\n",
        "                for C_p, k_2, b_span in right[j]:\n",
        "                    if k_2 == k + 1 and a_span + b_span == j - i + 1:\n",
        "                        B, C = B_p, C_p\n",
        "                        k_final = k\n",
        "                        break\n",
        "            if j > i:\n",
        "                batch_split[(i, j)] =k\n",
        "        splits.append(batch_split)\n",
        "    return splits "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jyDKutsIK3aw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "splits = split(spans)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}